from __future__ import division
from __future__ import print_function
import os
import glob
import time
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
import pickle
import sys
import datetime
import matplotlib
import csv
matplotlib.use('Agg')
import matplotlib.pyplot as plt

# NOTE(review): presumably this suppresses the OpenMP "duplicate runtime
# library" abort that can occur when several imported packages each bundle
# their own OpenMP runtime -- confirm it is still required, and whether it
# needs to be set before numpy/torch are imported to take effect.
os.environ['KMP_DUPLICATE_LIB_OK']='True'


class ARGS:
    """Run-time switches consumed by the module-level setup code.

    Args:
        cuda: Whether CUDA-specific setup (such as seeding the CUDA RNG)
            should be performed. Defaults to ``True``.
        seed: Seed for the random number generators. When ``None`` (the
            default) the current wall-clock time in whole seconds is used,
            so every run gets a different seed; pass an explicit integer to
            make a run reproducible.
    """

    def __init__(self, cuda=True, seed=None):
        # Flag read by the setup code to decide on CUDA-specific steps.
        self.cuda = cuda
        # Always stored as a plain int so it is accepted by every seeding API.
        self.seed = int(time.time()) if seed is None else int(seed)
# Build the shared settings object, then seed every random number generator
# used by this file (Python, NumPy, PyTorch CPU) with the same value.
args = ARGS()
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    _seed_rng(args.seed)
del _seed_rng  # do not leave the loop helper behind in the module namespace
# The CUDA generator is seeded only when CUDA use was requested.
if args.cuda:
    torch.cuda.manual_seed(args.seed)



# Input CSV read by load_dataset().
data_path = "C:\\Users\\main_dir_name\\Documents\\or_data\\process\\process\\demand_hr.csv"
# Output directory for plots/logs (trailing separator included).
# The original literal contained "\\\D", i.e. an escaped backslash followed
# by the invalid escape sequence "\D"; every separator is now a properly
# escaped single backslash.
plt_path = "C:\\Users\\main_dir_name\\Documents\\or\\POF_APP1_log\\"


class LSTMModel(nn.Module):
    """LSTM encoder followed by a linear read-out of the final time step.

    Args:
        input_size: Number of features per time step.
        hidden_size: Width of the LSTM hidden state.
        output_size: Number of values produced per sequence.
        num_layers: Number of stacked LSTM layers (default 1).
    """

    def __init__(self, input_size, hidden_size, output_size, num_layers=1):
        super().__init__()
        # batch_first=True: inputs/outputs are laid out (batch, seq, feature).
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Map ``x`` of shape (batch, seq, input_size) to (batch, output_size)."""
        # Per-step outputs of the top LSTM layer: (batch, seq, hidden_size).
        # The final hidden/cell states returned alongside are not needed.
        sequence_out, _ = self.lstm(x)
        # Only the last time step feeds the linear head.
        return self.fc(sequence_out[:, -1, :])




# ---- Shared module state (mutated by the functions below) ----

# Sliding-window inputs and their labels; filled in by load_dataset().
TOTAL_TRAIN_data_array = []
TOTAL_TRAIN_label_data_array = []

# Number of completed backpropagation() calls.
total_kky_count = 0

# Row count of each gradient accumulator below.
# NOTE(review): 8760 presumably is the number of hours in a year -- confirm.
n_intest = 8760

# Outputs and labels of the most recent update_nndata() call.
nn_output_array_all = []
label_batch_all = []

# One (n_intest, 2) accumulator per model; backpropagation() adds to these.
grad_array_acc_print = [np.zeros([n_intest, 2]) for _ in range(4)]

# Sizes recorded by update_nndata() and read back by backpropagation().
nsample = 0
network_fixed_size = 0

# The models are built with nclass + 1 outputs.
nclass = 2

# Per-series optimizers and models; appended to by load_dataset().
optimizer = []
model = []


def _parse_number(text):
    """Convert one CSV field to ``int`` when possible, otherwise to ``float``.

    Trying ``int`` first keeps integer-looking fields as integers while also
    accepting float spellings without a dot (for example ``1e-3``), which a
    plain "contains a dot" test would wrongly send to ``int``.
    """
    try:
        return int(text)
    except ValueError:
        return float(text)


def load_dataset():
    """Load the CSV at ``data_path`` and build the sliding-window training set.

    For each of the first four CSV columns, every run of 24 consecutive rows
    becomes one input window (each step wrapped as a single-feature list) and
    the value in the row right after the window becomes its label. The
    results are stored in the module globals ``TOTAL_TRAIN_data_array`` and
    ``TOTAL_TRAIN_label_data_array`` as NumPy arrays and printed.

    One ``LSTMModel`` and one Adam optimizer per column are created on the
    GPU the first time this runs. On later calls the existing pairs are kept
    (only their gradients are cleared) instead of appending duplicates.

    Returns:
        A one-element list holding the number of windows per column
        (number of data rows minus 24).

    Raises:
        ValueError: If a CSV field is not a valid number.
        IndexError: If a data row has fewer than four columns.
    """
    global TOTAL_TRAIN_data_array, TOTAL_TRAIN_label_data_array, model, optimizer
    window = 24
    n_series = 4

    data = []
    with open(data_path, mode='r', newline='') as file:
        for row in csv.reader(file):
            if not row:
                # csv.reader yields [] for blank lines; they hold no sample
                # and would otherwise break column indexing below.
                continue
            data.append([_parse_number(x) for x in row])

    n_windows = len(data) - window
    TOTAL_TRAIN_data_array = []
    TOTAL_TRAIN_label_data_array = []
    for j in range(n_series):
        # Column j as a flat list; skipped when there is too little data to
        # form even one window, so short files still yield empty arrays.
        series = [row[j] for row in data] if n_windows > 0 else []
        TOTAL_TRAIN_data_array.append(
            [[[value] for value in series[i:i + window]] for i in range(n_windows)]
        )
        # The label of window i is the row right after it, i.e. row i + 24.
        TOTAL_TRAIN_label_data_array.append(series[window:])

        if len(model) <= j:
            model.append(LSTMModel(1, 64, nclass+1).cuda())
            optimizer.append(optim.Adam(model[j].parameters()))
        optimizer[j].zero_grad()
    TOTAL_TRAIN_data_array = np.array(TOTAL_TRAIN_data_array)
    TOTAL_TRAIN_label_data_array = np.array(TOTAL_TRAIN_label_data_array)
    print(TOTAL_TRAIN_data_array)
    print(TOTAL_TRAIN_label_data_array)

    return [n_windows]



def set_nnparameter(dropout_request, alpha_request, lr_request, weight_decay_request):
    """Placeholder for applying hyper-parameter requests; currently a no-op.

    All four arguments are accepted and ignored: no model or optimizer is
    modified.

    Returns:
        None.
    """
    return None



def update_nndata(intest_index_array, sample_index, mode):
    """Run the four models forward on the selected windows.

    For every model the batch is the windows picked by ``sample_index``
    followed by the windows picked by ``intest_index_array``; labels are
    concatenated in the same order. The raw output tensors are kept in the
    module global ``nn_output_array_all`` (still attached to their autograd
    graph) so that ``backpropagation`` can use them afterwards. The globals
    ``label_batch_all``, ``nsample`` and ``network_fixed_size`` are updated
    as well.

    Args:
        intest_index_array: Window indices appended after the training samples.
        sample_index: Window indices of the training samples.
        mode: 1 switches the models to ``train()``, 2 to ``eval()``; any other
            value leaves their current mode unchanged.

    Returns:
        ``[param, labels, outputs]`` as nested Python lists, where ``param``
        is ``[nclass, nsample]``, ``labels`` holds one label row per model and
        ``outputs`` holds each model's output matrix.
    """
    global model, nn_output_array_all, label_batch_all, nsample, network_fixed_size, nclass

    nsample = 0
    print("           NN:: the index array length of current training sample: ", len(sample_index))
    label_batch_all = []
    nn_output_array_all = []

    for series in range(4):
        net = model[series]
        if mode == 1:
            net.train()
        elif mode == 2:
            net.eval()

        windows = TOTAL_TRAIN_data_array[series]
        targets = TOTAL_TRAIN_label_data_array[series]

        # Training part of the batch.
        train_inputs = torch.from_numpy(
            np.array([windows[i] for i in sample_index], dtype=np.float32))
        train_labels = np.array([targets[i] for i in sample_index], dtype=np.float32)
        nsample = len(sample_index)
        param = np.array([nclass, nsample])

        # In-test part, appended after the training part.
        intest_inputs = torch.from_numpy(
            np.array([windows[i] for i in intest_index_array], dtype=np.float32))
        batch_inputs = torch.cat([train_inputs, intest_inputs], dim=0)
        intest_labels = np.array([targets[i] for i in intest_index_array], dtype=np.float32)
        batch_labels = np.concatenate((train_labels.reshape((nsample)), intest_labels), axis=0)

        batch_inputs = batch_inputs.cuda()
        network_fixed_size = nsample + len(intest_index_array)
        nn_output_array_all.append(net(batch_inputs))
        label_batch_all.append(batch_labels)

    label_batch_all = np.array(label_batch_all, dtype=np.float32)
    # Everything handed back is converted to plain nested lists.
    return [
        param.tolist(),
        label_batch_all.tolist(),
        [out.cpu().detach().numpy().astype(np.float32).tolist() for out in nn_output_array_all],
    ]


# Number of completed backpropagation() calls (incremented together with
# total_kky_count).
pof_kky_count = 0
def backpropagation(nn_gradient_list, mode):
    """Apply externally supplied output gradients to the four models.

    For each model the supplied gradient is clipped to [-1e5, 1e5] and turned
    into a surrogate loss ``sum(grad * output)`` whose derivative with respect
    to the network output is exactly that gradient. A penalty of
    ``0.1 * sum(output[:, 1:] ** 2)`` is added, then one optimizer step is
    taken on the outputs stored by the last ``update_nndata`` call.

    The in-test rows of the clipped gradient (rows from ``nsample`` on,
    columns from 1 on) are added to ``grad_array_acc_print``. A gradient
    containing NaN is rejected *before* it is accumulated, so a single bad
    gradient can no longer turn the accumulator into NaN permanently.

    If a NaN is found in a gradient or in a stored output, a message is
    printed and the function returns immediately: models processed earlier in
    the loop keep their update, the remaining ones are left untouched and the
    call counters are not incremented.

    Args:
        nn_gradient_list: Four array-likes, one per model, each matching the
            shape of that model's stored output.
        mode: Accepted for interface compatibility; not used.

    Returns:
        None.
    """
    print("           NN:: START BACKPROPAGATION!")
    global model, optimizer, nsample, nn_output_array_all, label_batch_all, grad_array_acc_print, pof_kky_count, total_kky_count
    for j in range(4):
        nn_gradient_array = np.array(nn_gradient_list[j], dtype=np.float32)
        # Clip once and reuse for the NaN check, the accumulator and the loss.
        clipped = np.clip(nn_gradient_array, -100000., 100000.)
        # np.clip leaves NaN entries as NaN, so check before anything is
        # accumulated or applied.
        if np.isnan(clipped).any():
            print("           NN:: NAN in GRAD tensor")
            return
        # Original note on the accumulated slice:
        # (n_internal_test_case*n_stock, 1+2*n_class-1)
        grad_array_acc_print[j] += clipped[nsample:,1:].reshape((network_fixed_size-nsample), -1)
        outputs = nn_output_array_all[j]
        if bool(outputs.isnan().any()):
            print("           NN:: NAN in OUTPUT tensor")
            return
        grad_array = torch.from_numpy(clipped).cuda()
        tail = outputs[:,1:]
        loss = torch.sum(torch.mul(grad_array, outputs))+torch.sum(torch.mul(tail, tail))*0.1
        loss.backward()
        optimizer[j].step()
        optimizer[j].zero_grad()
    print("           NN:: COMPLETE BACKPROPAGATION!")
    total_kky_count+=1
    pof_kky_count+=1
        

